import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
from CNN import CNN
"""下载数据集"""
# MNIST training split. Files are stored under ./mnist; with download=True
# torchvision fetches them only when they are not already present there.
# ToTensor converts each image to a tensor before it is returned.
train_data = dataset.MNIST(
    root='mnist', train=True, transform=transforms.ToTensor(), download=True
)

# MNIST test split, stored in the same directory and converted the same way.
test_data = dataset.MNIST(
    root='mnist', train=False, transform=transforms.ToTensor(), download=True
)



# Training batches of 64 samples; shuffle=True reshuffles the data on every
# pass over the loader.
train_loader = data_utils.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)

# Test batches of 64 samples. Aggregate evaluation metrics do not depend on
# sample order, so the split is read sequentially; this keeps per-batch
# results reproducible between runs.
test_loader = data_utils.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=False,
)
# Backward-compatible alias for the original (misspelled) name.
tset_loader = test_loader


# Use the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cnn = CNN()
cnn = cnn.to(device)

# Only the first training batch is pushed through the network: the loop
# stops right after a single forward pass.
for index, (images, labels) in enumerate(train_loader):
    # Inputs must live on the same device as the model.
    images = images.to(device)
    labels = labels.to(device)
    outputs = cnn(images)
    break
